/*

 * Licensed to the Apache Software Foundation (ASF) under one

 * or more contributor license agreements.  See the NOTICE file

 * distributed with this work for additional information

 * regarding copyright ownership.  The ASF licenses this file

 * to you under the Apache License, Version 2.0 (the

 * "License"); you may not use this file except in compliance

 * with the License.  You may obtain a copy of the License at

 *

 *     http://www.apache.org/licenses/LICENSE-2.0

 *

 * Unless required by applicable law or agreed to in writing, software

 * distributed under the License is distributed on an "AS IS" BASIS,

 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

 * See the License for the specific language governing permissions and

 * limitations under the License.

 */

package com.bff.gaia.unified.sdk.io.gcp.pubsub;



import io.grpc.Status;

import io.grpc.StatusRuntimeException;

import com.bff.gaia.unified.sdk.coders.Coder;

import com.bff.gaia.unified.sdk.io.gcp.pubsub.PubsubClient.SubscriptionPath;

import com.bff.gaia.unified.sdk.io.gcp.pubsub.PubsubClient.TopicPath;

import com.bff.gaia.unified.sdk.state.BagState;

import com.bff.gaia.unified.sdk.state.StateSpec;

import com.bff.gaia.unified.sdk.state.StateSpecs;

import com.bff.gaia.unified.sdk.testing.TestPipeline;

import com.bff.gaia.unified.sdk.testing.TestPipelineOptions;

import com.bff.gaia.unified.sdk.transforms.*;

import com.bff.gaia.unified.sdk.transforms.Create;

import com.bff.gaia.unified.sdk.transforms.DoFn;

import com.bff.gaia.unified.sdk.transforms.PTransform;

import com.bff.gaia.unified.sdk.transforms.ParDo;

import com.bff.gaia.unified.sdk.transforms.SerializableFunction;

import com.bff.gaia.unified.sdk.transforms.WithKeys;

import com.bff.gaia.unified.sdk.transforms.windowing.GlobalWindows;

import com.bff.gaia.unified.sdk.transforms.windowing.Window;

import com.bff.gaia.unified.sdk.values.*;

import com.bff.gaia.unified.vendor.guava.com.google.common.base.Supplier;

import com.bff.gaia.unified.vendor.guava.com.google.common.base.Suppliers;

import com.bff.gaia.unified.vendor.guava.com.google.common.collect.ImmutableSet;

import com.bff.gaia.unified.sdk.values.KV;

import com.bff.gaia.unified.sdk.values.PBegin;

import com.bff.gaia.unified.sdk.values.PCollection;

import com.bff.gaia.unified.sdk.values.PDone;

import com.bff.gaia.unified.sdk.values.POutput;

import org.joda.time.DateTime;

import org.joda.time.Duration;

import org.junit.rules.TestRule;

import org.junit.runner.Description;

import org.junit.runners.model.Statement;

import org.slf4j.Logger;

import org.slf4j.LoggerFactory;



import javax.annotation.Nullable;

import java.io.IOException;

import java.util.List;

import java.util.Set;

import java.util.concurrent.ThreadLocalRandom;



import static java.nio.charset.StandardCharsets.UTF_8;

import static java.util.stream.Collectors.toList;

import static com.bff.gaia.unified.vendor.guava.com.google.common.base.Preconditions.checkState;



/**

 * Test rule which observes elements of the {@link PCollection} and checks whether they match the

 * success criteria.

 *

 * <p>Uses a random temporary Pubsub topic for synchronization.

 */

public class TestPubsubSignal implements TestRule {
  private static final Logger LOG = LoggerFactory.getLogger(TestPubsubSignal.class);
  private static final String RESULT_TOPIC_NAME = "result";
  private static final String RESULT_SUCCESS_MESSAGE = "SUCCESS";
  private static final String START_TOPIC_NAME = "start";
  private static final String START_SIGNAL_MESSAGE = "START SIGNAL";

  private static final String NO_ID_ATTRIBUTE = null;
  private static final String NO_TIMESTAMP_ATTRIBUTE = null;

  /** Delay between polling attempts when a pull fails or yields no message. */
  private static final long POLL_RETRY_DELAY_MILLIS = 500;

  // Non-null only while a test is being evaluated; reset to null by tearDown().
  PubsubClient pubsub;
  private final TestPubsubOptions pipelineOptions;
  // Each path is non-null only once the corresponding topic has been created, which marks it as
  // needing deletion in tearDown().
  private @Nullable TopicPath resultTopicPath = null;
  private @Nullable TopicPath startTopicPath = null;

  /**
   * Creates an instance of this rule.
   *
   * <p>Loads GCP configuration from {@link TestPipelineOptions}.
   */
  public static TestPubsubSignal create() {
    TestPubsubOptions options = TestPipeline.testingPipelineOptions().as(TestPubsubOptions.class);
    return new TestPubsubSignal(options);
  }

  private TestPubsubSignal(TestPubsubOptions pipelineOptions) {
    this.pipelineOptions = pipelineOptions;
  }

  /**
   * Wraps {@code base} so that the Pubsub client and both signal topics are created before the
   * test runs and torn down afterwards, whether or not the test succeeds.
   *
   * @throws AssertionError (from the returned statement) if the client of a previous test is
   *     still open
   */
  @Override
  public Statement apply(Statement base, Description description) {
    return new Statement() {
      @Override
      public void evaluate() throws Throwable {
        if (TestPubsubSignal.this.pubsub != null) {
          throw new AssertionError(
              "Pubsub client was not shutdown in previous test. "
                  + "Topic path is'"
                  + resultTopicPath
                  + "'. "
                  + "Current test: "
                  + description.getDisplayName());
        }

        try {
          initializePubsub(description);
          base.evaluate();
        } finally {
          tearDown();
        }
      }
    };
  }

  /**
   * Opens the Pubsub client and creates the result and start topics for the given test.
   *
   * <p>Each topic path field is assigned immediately after its topic is created, so that a failure
   * while creating the second topic still lets {@link #tearDown()} delete the first one.
   */
  private void initializePubsub(Description description) throws IOException {
    pubsub =
        PubsubGrpcClient.FACTORY.newClient(
            NO_TIMESTAMP_ATTRIBUTE, NO_ID_ATTRIBUTE, pipelineOptions);

    // Example topic name:
    //    integ-test-TestClassName-testMethodName-2018-12-11-23-32-333-<random-long>-result
    TopicPath resultTopicPathTmp =
        PubsubClient.topicPathFromName(
            pipelineOptions.getProject(), TestPubsub.createTopicName(description, RESULT_TOPIC_NAME));
    TopicPath startTopicPathTmp =
        PubsubClient.topicPathFromName(
            pipelineOptions.getProject(), TestPubsub.createTopicName(description, START_TOPIC_NAME));

    // Set each field right after successful creation; this signals that it needs teardown.
    pubsub.createTopic(resultTopicPathTmp);
    resultTopicPath = resultTopicPathTmp;

    pubsub.createTopic(startTopicPathTmp);
    startTopicPath = startTopicPathTmp;
  }

  /**
   * Deletes every topic that was created, closes the client and resets the rule's state.
   *
   * <p>Deletion of the start topic is attempted even if deleting the result topic fails, and the
   * fields are reset even if closing the client fails, so that a following test does not trip
   * over stale state.
   */
  private void tearDown() throws IOException {
    if (pubsub == null) {
      return;
    }

    try {
      try {
        if (resultTopicPath != null) {
          pubsub.deleteTopic(resultTopicPath);
        }
      } finally {
        if (startTopicPath != null) {
          pubsub.deleteTopic(startTopicPath);
        }
      }
    } finally {
      try {
        pubsub.close();
      } finally {
        pubsub = null;
        resultTopicPath = null;
        startTopicPath = null;
      }
    }
  }

  /** Outputs a message that the pipeline has started. */
  public PTransform<PBegin, PDone> signalStart() {
    return new PublishStart(startTopicPath);
  }

  /**
   * Outputs a success message when {@code successPredicate} is evaluated to true.
   *
   * <p>{@code successPredicate} is a {@link SerializableFunction} that accepts a set of currently
   * captured events and returns true when the set satisfies the success criteria.
   *
   * <p>If {@code successPredicate} is evaluated to false, then it will be re-evaluated when next
   * event becomes available.
   *
   * <p>If {@code successPredicate} is evaluated to true, then a success will be signaled and {@link
   * #waitForSuccess(Duration)} will unblock.
   *
   * <p>If {@code successPredicate} throws, then failure will be signaled and {@link
   * #waitForSuccess(Duration)} will unblock.
   */
  public <T> PTransform<PCollection<? extends T>, POutput> signalSuccessWhen(
      Coder<T> coder,
      SerializableFunction<T, String> formatter,
      SerializableFunction<Set<T>, Boolean> successPredicate) {

    return new PublishSuccessWhen<>(coder, formatter, successPredicate, resultTopicPath);
  }

  /**
   * Invocation of {@link #signalSuccessWhen(Coder, SerializableFunction, SerializableFunction)}
   * with {@link Object#toString} as the formatter.
   */
  public <T> PTransform<PCollection<? extends T>, POutput> signalSuccessWhen(
      Coder<T> coder, SerializableFunction<Set<T>, Boolean> successPredicate) {

    return signalSuccessWhen(coder, T::toString, successPredicate);
  }

  /**
   * Future that waits for a start signal for {@code duration}.
   *
   * <p>This future must be created before running the pipeline. A subscription must exist prior to
   * the start signal being published, which occurs immediately upon pipeline startup.
   *
   * <p>The subscription is created eagerly by this method and deleted (best effort) once the
   * returned supplier has finished polling.
   */
  public Supplier<Void> waitForStart(Duration duration) throws IOException {
    SubscriptionPath startSubscriptionPath =
        PubsubClient.subscriptionPathFromName(
            pipelineOptions.getProject(),
            "start-subscription-" + ThreadLocalRandom.current().nextLong());

    pubsub.createSubscription(
        startTopicPath, startSubscriptionPath, (int) duration.getStandardSeconds());

    return Suppliers.memoize(
        () -> {
          try {
            String result = pollForResultForDuration(startSubscriptionPath, duration);
            checkState(START_SIGNAL_MESSAGE.equals(result));
            return null;
          } catch (IOException e) {
            throw new RuntimeException(e);
          } finally {
            deleteSubscriptionQuietly(startSubscriptionPath);
          }
        });
  }

  /**
   * Wait for a success signal for {@code duration}.
   *
   * @throws AssertionError if no signal arrives in time, or if the received signal is anything
   *     other than the success message (the received text becomes the error message)
   */
  public void waitForSuccess(Duration duration) throws IOException {
    SubscriptionPath resultSubscriptionPath =
        PubsubClient.subscriptionPathFromName(
            pipelineOptions.getProject(),
            "result-subscription-" + ThreadLocalRandom.current().nextLong());

    pubsub.createSubscription(
        resultTopicPath, resultSubscriptionPath, (int) duration.getStandardSeconds());

    String result;
    try {
      result = pollForResultForDuration(resultSubscriptionPath, duration);
    } finally {
      // The subscription is single-use; do not leave it behind in the project.
      deleteSubscriptionQuietly(resultSubscriptionPath);
    }

    if (!RESULT_SUCCESS_MESSAGE.equals(result)) {
      throw new AssertionError(result);
    }
  }

  /**
   * Deletes a temporary subscription, logging instead of throwing when deletion fails.
   *
   * <p>Cleanup is best effort: a failure here must not mask the outcome of the wait itself.
   */
  private void deleteSubscriptionQuietly(SubscriptionPath subscriptionPath) {
    try {
      pubsub.deleteSubscription(subscriptionPath);
    } catch (Exception e) {
      // Broad catch is deliberate: this also covers the client having been torn down already.
      LOG.warn("Failed to delete Pubsub subscription {}; it may be leaked", subscriptionPath, e);
    }
  }

  /**
   * Polls the subscription until one message is received or {@code duration} has elapsed.
   *
   * <p>A pull that fails with a {@link StatusRuntimeException} or that returns no messages is
   * retried after a short delay. The received message is acknowledged before returning.
   *
   * @return the payload of the received message decoded as UTF-8
   * @throws AssertionError if no message was received before the deadline
   */
  private String pollForResultForDuration(
      SubscriptionPath signalSubscriptionPath, Duration duration) throws IOException {

    List<PubsubClient.IncomingMessage> signal = null;
    DateTime endPolling = DateTime.now().plus(duration.getMillis());

    do {
      try {
        List<PubsubClient.IncomingMessage> pulled =
            pubsub.pull(DateTime.now().getMillis(), signalSubscriptionPath, 1, false);
        if (pulled == null || pulled.isEmpty()) {
          // A pull can complete without delivering anything; treat it as "not yet" rather than
          // as a result, otherwise reading the first element below would fail.
          sleep(POLL_RETRY_DELAY_MILLIS);
          continue;
        }
        signal = pulled;
        pubsub.acknowledge(
            signalSubscriptionPath, signal.stream().map(m -> m.ackId).collect(toList()));
        break;
      } catch (StatusRuntimeException e) {
        // Deadline-exceeded is the expected outcome of an idle pull, so only log other statuses.
        if (!Status.DEADLINE_EXCEEDED.equals(e.getStatus())) {
          LOG.warn(
              "(Will retry) Error while polling {} for signal: {}",
              signalSubscriptionPath,
              e.getStatus());
        }
        sleep(POLL_RETRY_DELAY_MILLIS);
      }
    } while (DateTime.now().isBefore(endPolling));

    if (signal == null) {
      throw new AssertionError(
          String.format(
              "Did not receive signal on %s in %ss",
              signalSubscriptionPath, duration.getStandardSeconds()));
    }

    return new String(signal.get(0).elementBytes, UTF_8);
  }

  /**
   * Sleeps for {@code t} milliseconds.
   *
   * @throws RuntimeException wrapping the {@link InterruptedException} if the thread is
   *     interrupted; the interrupt flag is restored first so callers up the stack can observe it
   */
  private void sleep(long t) {
    try {
      Thread.sleep(t);
    } catch (InterruptedException ex) {
      Thread.currentThread().interrupt();
      throw new RuntimeException(ex);
    }
  }

  /** {@link PTransform} that signals once when the pipeline has started. */
  static class PublishStart extends PTransform<PBegin, PDone> {
    private final TopicPath startTopicPath;

    PublishStart(TopicPath startTopicPath) {
      this.startTopicPath = startTopicPath;
    }

    @Override
    public PDone expand(PBegin input) {
      return input
          .apply("Start signal", Create.of(START_SIGNAL_MESSAGE))
          .apply(PubsubIO.writeStrings().to(startTopicPath.getPath()));
    }
  }

  /** {@link PTransform} that validates whether elements seen so far match success criteria. */
  static class PublishSuccessWhen<T> extends PTransform<PCollection<? extends T>, POutput> {
    private final Coder<T> coder;
    private final SerializableFunction<T, String> formatter;
    private final SerializableFunction<Set<T>, Boolean> successPredicate;
    private final TopicPath resultTopicPath;

    PublishSuccessWhen(
        Coder<T> coder,
        SerializableFunction<T, String> formatter,
        SerializableFunction<Set<T>, Boolean> successPredicate,
        TopicPath resultTopicPath) {

      this.coder = coder;
      this.formatter = formatter;
      this.successPredicate = successPredicate;
      this.resultTopicPath = resultTopicPath;
    }

    @Override
    public POutput expand(PCollection<? extends T> input) {
      return input
          // assign a dummy key and global window,
          // this is needed to accumulate all observed events in the same state cell
          .apply(Window.into(new GlobalWindows()))
          .apply(WithKeys.of("dummyKey"))
          .apply(
              "checkAllEventsForSuccess",
              ParDo.of(new StatefulPredicateCheck<>(coder, formatter, successPredicate)))
          // signal the success/failure to the result topic
          .apply("publishSuccess", PubsubIO.writeStrings().to(resultTopicPath.getPath()));
    }
  }

  /**
   * Stateful {@link DoFn} which caches the elements it sees and checks whether they satisfy the
   * predicate.
   *
   * <p>When predicate is satisfied outputs "SUCCESS". If predicate throws exception, outputs
   * "FAILURE: " followed by the exception message.
   */
  static class StatefulPredicateCheck<T> extends DoFn<KV<String, ? extends T>, String> {
    private final SerializableFunction<T, String> formatter;
    private final SerializableFunction<Set<T>, Boolean> successPredicate;

    // keep all events seen so far in the state cell
    private static final String SEEN_EVENTS = "seenEvents";

    @StateId(SEEN_EVENTS)
    private final StateSpec<BagState<T>> seenEvents;

    StatefulPredicateCheck(
        Coder<T> coder,
        SerializableFunction<T, String> formatter,
        SerializableFunction<Set<T>, Boolean> successPredicate) {
      this.seenEvents = StateSpecs.bag(coder);
      this.formatter = formatter;
      this.successPredicate = successPredicate;
    }

    /**
     * Adds the current element to the state bag and re-evaluates the predicate over the set of
     * all elements seen so far.
     */
    @ProcessElement
    public void processElement(
        ProcessContext context, @StateId(SEEN_EVENTS) BagState<T> seenEvents) {

      seenEvents.add(context.element().getValue());
      ImmutableSet<T> eventsSoFar = ImmutableSet.copyOf(seenEvents.read());

      // check if all elements seen so far satisfy the success predicate
      try {
        if (successPredicate.apply(eventsSoFar)) {
          context.output("SUCCESS");
        }
      } catch (Throwable e) {
        // Deliberately broad: any failure of the predicate (including assertion errors) is
        // reported through the result topic instead of failing the bundle.
        context.output("FAILURE: " + e.getMessage());
      }
    }
  }
}